In [17]:
# ============================================================
# OASIS-2 – Baseline Dementia Classification (Combined Features)
#
# Features used (NO CDR as predictor):
# ['age', 'mmse', 'educ', 'ses', 'nwbv', 'etiv', 'asf', 'sex_enc']
#
# Label:
# dementia_label = 1 if Demented or Converted, 0 if Nondemented
#
# Models:
# - Logistic Regression
# - Random Forest
#
# Outputs:
# - Descriptive statistics
# - ROC / PR curves (RF, test split)
# - 5-fold CV metrics for RF (main results)
# - SHAP summaries for RF (combined model)
# - Summary text file: oasis_dementia_combined_summary.txt
# ============================================================
# --- Environment setup ------------------------------------------------
# Stdlib imports plus global warning suppression (notebook convenience;
# silences sklearn/shap deprecation noise in the printed output).
import os, warnings
warnings.filterwarnings("ignore")

# Install shap if needed (Colab images do not always ship it).
try:
    import shap  # noqa
except ImportError:
    import sys, subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "shap"])
    import shap  # noqa

# Fix deprecated numpy aliases for shap/sklearn internals:
# np.bool / np.float were removed from NumPy, but older shap builds may
# still reference them, so restore them as plain builtins if absent.
import numpy as np
if not hasattr(np, "bool"):
    np.bool = bool
if not hasattr(np, "float"):
    np.float = float

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
    roc_auc_score,
    average_precision_score,
    roc_curve,
    precision_recall_curve,
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
)
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

# Global plotting style and reproducibility seed for numpy-based sampling.
sns.set(style="whitegrid", context="talk")
RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
# ------------------------------------------------------------
# 1. Load dataset
# ------------------------------------------------------------
# Expects the OASIS-2 longitudinal demographics workbook in the working
# directory (uploaded manually in Colab); fails fast with a clear message
# otherwise.
file_name = "oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx"
if not os.path.exists(file_name):
    raise FileNotFoundError(
        f"Could not find file: {file_name}. "
        "Upload it in the Colab Files pane (left sidebar)."
    )
df = pd.read_excel(file_name)
print("Raw shape:", df.shape)
print("Columns:", df.columns.tolist())
display(df.head())
# ------------------------------------------------------------
# 2. Basic cleaning and type handling
# ------------------------------------------------------------
# Standardize column names: lower-case, spaces and slashes -> underscores
# (e.g. "Subject ID" -> "subject_id", "M/F" -> "m_f", "MR Delay" -> "mr_delay").
df.columns = [c.strip().lower().replace(" ", "_").replace("/", "_") for c in df.columns]

# Rename only the columns whose normalized names still differ from the
# names used throughout the analysis.  (The original rename map also listed
# every already-matching column as an identity mapping, which was a no-op.)
df = df.rename(
    columns={
        "mr_delay": "mr_delay_days",  # delay between scans is recorded in days
        "m_f": "sex",
    }
)

# Fail fast if the workbook does not have the expected OASIS-2 layout.
expected_cols = [
    "subject_id", "mri_id", "group", "visit", "mr_delay_days",
    "sex", "hand", "age", "educ", "ses", "mmse", "cdr", "etiv", "nwbv", "asf"
]
missing = [c for c in expected_cols if c not in df.columns]
if missing:
    raise ValueError(f"Missing expected columns: {missing}")

# Normalize dtypes: IDs as strings, clinical/volumetric columns numeric.
# Non-parseable cells become NaN and are handled by the imputers later.
df["subject_id"] = df["subject_id"].astype(str)
for c in ["visit", "mr_delay_days", "age", "educ", "ses", "mmse", "cdr", "etiv", "nwbv", "asf"]:
    df[c] = pd.to_numeric(df[c], errors="coerce")

# Convenience column: scan delay in years (365.25 accounts for leap years).
df["mr_delay_years"] = df["mr_delay_days"] / 365.25
# ------------------------------------------------------------
# 3. Missingness inspection
# ------------------------------------------------------------
# Per-column fraction of NaN entries, largest first, shown both as a
# table and as a bar chart.
print("\nFraction of missing values per column:")
na_fractions = df.isna().mean()
missing_frac = na_fractions.sort_values(ascending=False)
display(missing_frac)

plt.figure(figsize=(9, 4))
missing_frac.plot(kind="bar")
plt.ylabel("Fraction missing")
plt.title("Missing data by column")
plt.tight_layout()
plt.show()
# ------------------------------------------------------------
# 4. Baseline-level dataset (MR_DELAY == 0)
# Label: 1 = Demented or Converted, 0 = Nondemented
# ------------------------------------------------------------
# Keep only the first scan per subject (zero-day delay) and derive a
# binary dementia label from the clinical group assignment.
baseline = df.loc[df["mr_delay_days"] == 0].copy()
print("\nBaseline shape:", baseline.shape)
print("Baseline groups:")
display(baseline["group"].value_counts())

label_map = {"Nondemented": 0, "Demented": 1, "Converted": 1}
baseline["dementia_label"] = baseline["group"].map(label_map)
# Any group value outside the map becomes NaN and is dropped here.
baseline = baseline.dropna(subset=["dementia_label"])
baseline["dementia_label"] = baseline["dementia_label"].astype(int)

print("\nDementia label distribution at baseline (0 = nondemented, 1 = demented/converted):")
display(baseline["dementia_label"].value_counts())

plt.figure(figsize=(4, 4))
sns.countplot(data=baseline, x="dementia_label")
plt.xticks([0, 1], ["Nondemented", "Demented/Converted"])
plt.title("Baseline dementia status")
plt.tight_layout()
plt.show()
# ------------------------------------------------------------
# 5. Longitudinal summary per subject (descriptive only)
# ------------------------------------------------------------
# Aggregate the FULL longitudinal table (all visits) for the subjects in
# the baseline cohort.  These per-subject summaries are merged back for
# description only -- they are NOT used as model features.
long_summary = (
    df[df["subject_id"].isin(baseline["subject_id"])]
    .groupby("subject_id")
    .agg(
        n_visits=("visit", "max"),  # max visit index; assumes visits numbered 1..k — TODO confirm
        max_delay_days=("mr_delay_days", "max"),
        max_cdr=("cdr", "max"),
        min_mmse=("mmse", "min"),
    )
    .reset_index()
)
long_summary["followup_years"] = long_summary["max_delay_days"] / 365.25
# Left merge keeps all baseline rows even if a subject had no later visits.
baseline = baseline.merge(long_summary, on="subject_id", how="left")

print("\nFollow-up years summary (all baseline subjects):")
display(baseline["followup_years"].describe())

plt.figure(figsize=(6, 4))
sns.histplot(baseline["followup_years"], bins=15, kde=True)
plt.xlabel("Follow-up years")
plt.title("Distribution of follow-up duration")
plt.tight_layout()
plt.show()
# ------------------------------------------------------------
# 6. Prepare combined features + label (NO CDR)
# ------------------------------------------------------------
# CDR is deliberately excluded as a predictor: it is the clinical rating
# underlying the group label itself.
# Encode sex
sex_map = {"F": 0, "M": 1, "Female": 0, "Male": 1}
baseline["sex_enc"] = baseline["sex"].map(sex_map)

label_col = "dementia_label"
y = baseline[label_col].astype(int).values
combined_features = ["age", "mmse", "educ", "ses", "nwbv", "etiv", "asf", "sex_enc"]

print("\nFinal modeling dataset size:", baseline.shape[0])
print("Combined feature set:", combined_features)
print("Nondemented:", int((y == 0).sum()),
      "| Demented/Converted:", int((y == 1).sum()))

# Correlation heatmap of features plus the binary target (Pearson).
corr_df = baseline[combined_features + [label_col]].copy()
plt.figure(figsize=(10, 8))
corr = corr_df.corr()
sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", vmin=-1, vmax=1)
plt.title("Correlation matrix – Combined features + dementia status")
plt.tight_layout()
plt.show()

# KDEs for key features, one panel per feature, split by dementia status.
plt.figure(figsize=(14, 4))
for i, col in enumerate(["age", "mmse", "nwbv"], start=1):
    if col not in baseline.columns:
        continue
    plt.subplot(1, 3, i)
    sns.kdeplot(
        data=baseline,
        x=col,
        hue=label_col,
        common_norm=False,  # normalize each class separately
        fill=True,
        alpha=0.4,
    )
    plt.title(col)
    plt.xlabel(col)
plt.suptitle("Baseline distributions by dementia status", y=1.05)
plt.tight_layout()
plt.show()
# ------------------------------------------------------------
# 7. Helper functions for thresholds and metrics
# ------------------------------------------------------------
def compute_metrics_at_threshold(name, y_true, proba, thr):
    """Binarize probabilities at ``thr``, then print and return metrics.

    Parameters
    ----------
    name : str
        Label used in the printed report header.
    y_true : array-like of {0, 1}
        Ground-truth binary labels.
    proba : numpy array of float
        Predicted probabilities for the positive class.
    thr : float
        Decision threshold; predictions are ``proba >= thr``.

    Returns
    -------
    dict
        Threshold, ranking metrics (ROC-AUC, PR-AUC), thresholded metrics
        (accuracy, sensitivity, specificity, precision) and the four
        confusion-matrix cells (tp/fp/tn/fn).
    """
    y_pred = (proba >= thr).astype(int)
    # Ranking metrics are threshold-independent; reported here for convenience.
    auc = roc_auc_score(y_true, proba)
    pr_auc = average_precision_score(y_true, proba)
    acc = accuracy_score(y_true, y_pred)
    # zero_division=0 keeps degenerate inputs (no actual/predicted positives)
    # from emitting warnings — consistent with the precision_score call below.
    sens = recall_score(y_true, y_pred, zero_division=0)
    prec = precision_score(y_true, y_pred, zero_division=0)
    # Always force a 2x2 confusion matrix via labels=[0, 1], even when the
    # input happens to contain a single class.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    spec = tn / (tn + fp) if (tn + fp) > 0 else np.nan
    print(f"\n=== {name} (threshold = {thr:.3f}) ===")
    print(f"ROC-AUC: {auc:.3f}")
    print(f"PR-AUC: {pr_auc:.3f}")
    print(f"Accuracy: {acc:.3f}")
    print(f"Sensitivity: {sens:.3f}")
    print(f"Specificity: {spec:.3f}")
    print(f"Precision (PPV): {prec:.3f}")
    print(f"Confusion: TP={tp}, FP={fp}, TN={tn}, FN={fn}")
    return {
        "threshold": thr,
        "auc": auc,
        "pr_auc": pr_auc,
        "acc": acc,
        "sens": sens,
        "spec": spec,
        "prec": prec,
        "tp": tp,
        "fp": fp,
        "tn": tn,
        "fn": fn,
    }
def choose_threshold_roc_optimal(y_true, proba):
    """Return the threshold whose ROC point is closest to (FPR=0, TPR=1).

    Uses squared Euclidean distance to the ideal top-left corner.  The
    first point returned by ``roc_curve`` is a synthetic (0, 0) anchor
    whose threshold is ``inf`` on sklearn >= 1.2 (``max(proba) + 1``
    earlier); it is skipped so the function can never return a threshold
    that classifies everything as negative.
    """
    fpr, tpr, thr = roc_curve(y_true, proba)
    dist2 = (fpr ** 2) + ((1 - tpr) ** 2)  # squared distance to (0, 1)
    if len(thr) > 1:
        # Skip index 0: the synthetic anchor is never a usable threshold.
        idx = 1 + int(np.argmin(dist2[1:]))
    else:
        idx = 0
    return thr[idx]
# ------------------------------------------------------------
# 8. Simple train/test split evaluation (for intuition)
# ------------------------------------------------------------
# Single stratified 80/20 split.  With only ~150 subjects these numbers
# are noisy; the 5-fold CV in section 9 is the headline result.
X = baseline[combined_features].values
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_STATE
)
print("\nTrain size:", X_train.shape[0], "| Test size:", X_test.shape[0])

# Median imputation fitted on the training split only (no test leakage).
imputer = SimpleImputer(strategy="median")
X_train_imp = imputer.fit_transform(X_train)
X_test_imp = imputer.transform(X_test)

# Logistic Regression
log_reg = LogisticRegression(max_iter=1000, random_state=RANDOM_STATE)
log_reg.fit(X_train_imp, y_train)
proba_lr = log_reg.predict_proba(X_test_imp)[:, 1]

# RF
rf_model = RandomForestClassifier(
    n_estimators=300,
    max_depth=None,
    min_samples_split=4,
    min_samples_leaf=2,
    random_state=RANDOM_STATE,
    class_weight="balanced",  # compensates the mild class imbalance
)
rf_model.fit(X_train_imp, y_train)
proba_rf = rf_model.predict_proba(X_test_imp)[:, 1]

# Report each model at the fixed 0.5 threshold and at the ROC-optimal one.
# NOTE(review): the ROC-optimal threshold is selected using the TEST set's
# own labels, so those rows are optimistically biased.
for model_name, proba in [("Logistic Regression", proba_lr), ("Random Forest", proba_rf)]:
    # Fixed 0.5
    _ = compute_metrics_at_threshold(
        f"Combined – {model_name} (fixed 0.5)",
        y_test,
        proba,
        thr=0.5,
    )
    # ROC-optimal
    thr_opt = choose_threshold_roc_optimal(y_test, proba)
    _ = compute_metrics_at_threshold(
        f"Combined – {model_name} (ROC-optimal)",
        y_test,
        proba,
        thr=thr_opt,
    )

# ROC & PR curves for RF on test set
plt.figure(figsize=(6, 5))
fpr, tpr, _ = roc_curve(y_test, proba_rf)
auc_rf = roc_auc_score(y_test, proba_rf)
plt.plot(fpr, tpr, label=f"RF (AUC={auc_rf:.2f})")
plt.plot([0, 1], [0, 1], "k--", alpha=0.4)  # chance diagonal
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC curve – Random Forest (combined features, test set)")
plt.legend()
plt.tight_layout()
plt.show()

plt.figure(figsize=(6, 5))
prec, rec, _ = precision_recall_curve(y_test, proba_rf)
ap_rf = average_precision_score(y_test, proba_rf)
plt.plot(rec, prec, label=f"RF (AP={ap_rf:.2f})")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision–Recall curve – Random Forest (combined, test set)")
plt.legend()
plt.tight_layout()
plt.show()
# ------------------------------------------------------------
# 9. 5-fold cross-validation for RF (main performance)
# ------------------------------------------------------------
print("\n\n===== 5-fold cross-validation – Random Forest (combined features) =====")
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
cv_results = []
for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X, y), start=1):
    X_tr, X_val = X[train_idx], X[val_idx]
    y_tr, y_val = y[train_idx], y[val_idx]
    # Imputer is re-fit inside each fold so validation rows never leak
    # into the imputation statistics.
    imputer_cv = SimpleImputer(strategy="median")
    X_tr_imp = imputer_cv.fit_transform(X_tr)
    X_val_imp = imputer_cv.transform(X_val)
    # Same RF configuration as the train/test model in section 8.
    rf_cv = RandomForestClassifier(
        n_estimators=300,
        max_depth=None,
        min_samples_split=4,
        min_samples_leaf=2,
        random_state=RANDOM_STATE,
        class_weight="balanced",
    )
    rf_cv.fit(X_tr_imp, y_tr)
    proba_val = rf_cv.predict_proba(X_val_imp)[:, 1]
    # AUC / PR-AUC (threshold-independent, computed once per fold)
    auc_val = roc_auc_score(y_val, proba_val)
    pr_auc_val = average_precision_score(y_val, proba_val)
    # Thresholds: 0.5 and ROC-optimal.  NOTE(review): the ROC-optimal
    # threshold is tuned on the fold's own validation labels, so those
    # rows are optimistically biased.
    for thr_name, thr in [
        ("fixed_0.5", 0.5),
        ("roc_opt", choose_threshold_roc_optimal(y_val, proba_val)),
    ]:
        metrics = compute_metrics_at_threshold(
            f"Fold {fold_idx} – RF (combined, {thr_name})",
            y_val,
            proba_val,
            thr=thr,
        )
        metrics["fold"] = fold_idx
        metrics["thr_name"] = thr_name
        metrics["auc"] = auc_val  # override with CV-specific auc
        metrics["pr_auc"] = pr_auc_val
        cv_results.append(metrics)

cv_df = pd.DataFrame(cv_results)
display(cv_df.head())

print("\nCross-validation summary (Random Forest, combined features):")
# Mean ± std of each metric, grouped by threshold strategy.
summary_cv = (
    cv_df.groupby("thr_name")[["auc", "pr_auc", "acc", "sens", "spec", "prec"]]
    .agg(["mean", "std"])
)
display(summary_cv)
# ------------------------------------------------------------
# 10. SHAP explainability for final RF model (trained on full data)
# ------------------------------------------------------------
# Explanations only: performance numbers come from the CV above, never
# from this full-data model.
print("\nTraining final RF on full combined dataset for SHAP explanations...")
imputer_full = SimpleImputer(strategy="median")
X_full_imp = imputer_full.fit_transform(X)
rf_full = RandomForestClassifier(
    n_estimators=300,
    max_depth=None,
    min_samples_split=4,
    min_samples_leaf=2,
    random_state=RANDOM_STATE,
    class_weight="balanced",
)
rf_full.fit(X_full_imp, y)

try:
    shap.initjs()
    # Background sample (at most 100 rows) keeps TreeExplainer fast.
    bg_size = min(100, X_full_imp.shape[0])
    rng = np.random.RandomState(RANDOM_STATE)
    idx_bg = rng.choice(X_full_imp.shape[0], size=bg_size, replace=False)
    background = X_full_imp[idx_bg]
    explainer = shap.TreeExplainer(rf_full, data=background)
    shap_values = explainer.shap_values(X_full_imp)
    # Older shap returns a per-class list: take class 1 (positive class).
    # Newer versions return a single array, used as-is — presumably the
    # positive-class attributions; TODO confirm for the installed version.
    if isinstance(shap_values, list):
        shap_values_class1 = shap_values[1]
    else:
        shap_values_class1 = shap_values
    # Beeswarm summary plot
    shap.summary_plot(
        shap_values_class1,
        X_full_imp,
        feature_names=combined_features,
        show=False,  # keep the figure open so a title can be added
    )
    plt.title("SHAP Beeswarm – RF (combined features)")
    plt.tight_layout()
    plt.show()
    # Bar plot
    shap.summary_plot(
        shap_values_class1,
        X_full_imp,
        feature_names=combined_features,
        plot_type="bar",
        show=False,
    )
    plt.title("Mean |SHAP| Feature Importance – RF (combined features)")
    plt.tight_layout()
    plt.show()
except Exception as e:
    # shap/numpy version mismatches are common; the analysis stands
    # without the plots, so report and continue.
    print("SHAP plotting encountered an error (plots skipped):", e)
# ------------------------------------------------------------
# 11. Write summary text file
# ------------------------------------------------------------
summary_lines = []

def add_line(s=""):
    """Append one line to the summary buffer (default: blank line)."""
    summary_lines.append(s)

N = baseline.shape[0]
n_dem = int(y.sum())
add_line("OASIS-2 Baseline Dementia Classification – Combined Features (No CDR as Predictor)")
add_line("==============================================================================")
add_line(f"Total baseline subjects: {N}")
add_line(f"Demented/Converted: {n_dem} ({n_dem / N * 100:.1f}%)")
add_line(f"Nondemented: {N - n_dem} ({(N - n_dem) / N * 100:.1f}%)")
add_line("")
add_line("Combined feature set:")
add_line(", ".join(combined_features))
add_line("")
add_line("Random Forest – 5-fold cross-validation (combined features)")
add_line("Metrics reported as mean ± std over folds.")
add_line("Two thresholds considered: fixed 0.5 and ROC-optimal per fold.")
add_line("")
# One block per threshold strategy, aggregated over the five CV folds.
for thr_name in ["fixed_0.5", "roc_opt"]:
    sub = cv_df[cv_df["thr_name"] == thr_name]
    if sub.empty:
        continue
    auc_mean, auc_std = sub["auc"].mean(), sub["auc"].std()
    pr_mean, pr_std = sub["pr_auc"].mean(), sub["pr_auc"].std()
    acc_mean, acc_std = sub["acc"].mean(), sub["acc"].std()
    sens_mean, sens_std = sub["sens"].mean(), sub["sens"].std()
    spec_mean, spec_std = sub["spec"].mean(), sub["spec"].std()
    prec_mean, prec_std = sub["prec"].mean(), sub["prec"].std()
    add_line(f"Threshold strategy: {thr_name}")
    add_line(f" AUC: {auc_mean:.3f} ± {auc_std:.3f}")
    add_line(f" PR-AUC: {pr_mean:.3f} ± {pr_std:.3f}")
    add_line(f" Acc: {acc_mean:.3f} ± {acc_std:.3f}")
    add_line(f" Sens: {sens_mean:.3f} ± {sens_std:.3f}")
    add_line(f" Spec: {spec_mean:.3f} ± {spec_std:.3f}")
    add_line(f" PPV: {prec_mean:.3f} ± {prec_std:.3f}")
    add_line("")

summary_path = "oasis_dementia_combined_summary.txt"
with open(summary_path, "w") as f:
    f.write("\n".join(summary_lines))
print(f"\nSummary written to {summary_path}")
print("\nAnalysis complete.")
Raw shape: (373, 15) Columns: ['Subject ID', 'MRI ID', 'Group', 'Visit', 'MR Delay', 'M/F', 'Hand', 'Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF']
| Subject ID | MRI ID | Group | Visit | MR Delay | M/F | Hand | Age | EDUC | SES | MMSE | CDR | eTIV | nWBV | ASF | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | OAS2_0001 | OAS2_0001_MR1 | Nondemented | 1 | 0 | M | R | 87 | 14 | 2.0 | 27.0 | 0.0 | 1986.550000 | 0.696106 | 0.883440 |
| 1 | OAS2_0001 | OAS2_0001_MR2 | Nondemented | 2 | 457 | M | R | 88 | 14 | 2.0 | 30.0 | 0.0 | 2004.479526 | 0.681062 | 0.875539 |
| 2 | OAS2_0002 | OAS2_0002_MR1 | Demented | 1 | 0 | M | R | 75 | 12 | NaN | 23.0 | 0.5 | 1678.290000 | 0.736336 | 1.045710 |
| 3 | OAS2_0002 | OAS2_0002_MR2 | Demented | 2 | 560 | M | R | 76 | 12 | NaN | 28.0 | 0.5 | 1737.620000 | 0.713402 | 1.010000 |
| 4 | OAS2_0002 | OAS2_0002_MR3 | Demented | 3 | 1895 | M | R | 80 | 12 | NaN | 22.0 | 0.5 | 1697.911134 | 0.701236 | 1.033623 |
Fraction of missing values per column:
| 0 | |
|---|---|
| ses | 0.050938 |
| mmse | 0.005362 |
| subject_id | 0.000000 |
| mri_id | 0.000000 |
| mr_delay_days | 0.000000 |
| sex | 0.000000 |
| group | 0.000000 |
| visit | 0.000000 |
| age | 0.000000 |
| hand | 0.000000 |
| educ | 0.000000 |
| cdr | 0.000000 |
| etiv | 0.000000 |
| nwbv | 0.000000 |
| asf | 0.000000 |
| mr_delay_years | 0.000000 |
Baseline shape: (150, 16) Baseline groups:
| count | |
|---|---|
| group | |
| Nondemented | 72 |
| Demented | 64 |
| Converted | 14 |
Dementia label distribution at baseline (0 = nondemented, 1 = demented/converted):
| count | |
|---|---|
| dementia_label | |
| 1 | 78 |
| 0 | 72 |
Follow-up years summary (all baseline subjects):
| followup_years | |
|---|---|
| count | 150.000000 |
| mean | 2.925521 |
| std | 1.477717 |
| min | 0.999316 |
| 25% | 1.774812 |
| 50% | 2.316222 |
| 75% | 3.919918 |
| max | 7.225188 |
Final modeling dataset size: 150 Combined feature set: ['age', 'mmse', 'educ', 'ses', 'nwbv', 'etiv', 'asf', 'sex_enc'] Nondemented: 72 | Demented/Converted: 78
Train size: 120 | Test size: 30 === Combined – Logistic Regression (fixed 0.5) (threshold = 0.500) === ROC-AUC: 0.705 PR-AUC: 0.798 Accuracy: 0.600 Sensitivity: 0.500 Specificity: 0.714 Precision (PPV): 0.667 Confusion: TP=8, FP=4, TN=10, FN=8 === Combined – Logistic Regression (ROC-optimal) (threshold = 0.438) === ROC-AUC: 0.705 PR-AUC: 0.798 Accuracy: 0.633 Sensitivity: 0.625 Specificity: 0.643 Precision (PPV): 0.667 Confusion: TP=10, FP=5, TN=9, FN=6 === Combined – Random Forest (fixed 0.5) (threshold = 0.500) === ROC-AUC: 0.714 PR-AUC: 0.812 Accuracy: 0.700 Sensitivity: 0.562 Specificity: 0.857 Precision (PPV): 0.818 Confusion: TP=9, FP=2, TN=12, FN=7 === Combined – Random Forest (ROC-optimal) (threshold = 0.645) === ROC-AUC: 0.714 PR-AUC: 0.812 Accuracy: 0.733 Sensitivity: 0.562 Specificity: 0.929 Precision (PPV): 0.900 Confusion: TP=9, FP=1, TN=13, FN=7
===== 5-fold cross-validation – Random Forest (combined features) ===== === Fold 1 – RF (combined, fixed_0.5) (threshold = 0.500) === ROC-AUC: 0.836 PR-AUC: 0.893 Accuracy: 0.833 Sensitivity: 0.800 Specificity: 0.867 Precision (PPV): 0.857 Confusion: TP=12, FP=2, TN=13, FN=3 === Fold 1 – RF (combined, roc_opt) (threshold = 0.550) === ROC-AUC: 0.836 PR-AUC: 0.893 Accuracy: 0.833 Sensitivity: 0.800 Specificity: 0.867 Precision (PPV): 0.857 Confusion: TP=12, FP=2, TN=13, FN=3 === Fold 2 – RF (combined, fixed_0.5) (threshold = 0.500) === ROC-AUC: 0.849 PR-AUC: 0.816 Accuracy: 0.800 Sensitivity: 0.800 Specificity: 0.800 Precision (PPV): 0.800 Confusion: TP=12, FP=3, TN=12, FN=3 === Fold 2 – RF (combined, roc_opt) (threshold = 0.618) === ROC-AUC: 0.849 PR-AUC: 0.816 Accuracy: 0.833 Sensitivity: 0.733 Specificity: 0.933 Precision (PPV): 0.917 Confusion: TP=11, FP=1, TN=14, FN=4 === Fold 3 – RF (combined, fixed_0.5) (threshold = 0.500) === ROC-AUC: 0.808 PR-AUC: 0.860 Accuracy: 0.733 Sensitivity: 0.562 Specificity: 0.929 Precision (PPV): 0.900 Confusion: TP=9, FP=1, TN=13, FN=7 === Fold 3 – RF (combined, roc_opt) (threshold = 0.449) === ROC-AUC: 0.808 PR-AUC: 0.860 Accuracy: 0.733 Sensitivity: 0.688 Specificity: 0.786 Precision (PPV): 0.786 Confusion: TP=11, FP=3, TN=11, FN=5 === Fold 4 – RF (combined, fixed_0.5) (threshold = 0.500) === ROC-AUC: 0.857 PR-AUC: 0.899 Accuracy: 0.700 Sensitivity: 0.688 Specificity: 0.714 Precision (PPV): 0.733 Confusion: TP=11, FP=4, TN=10, FN=5 === Fold 4 – RF (combined, roc_opt) (threshold = 0.431) === ROC-AUC: 0.857 PR-AUC: 0.899 Accuracy: 0.833 Sensitivity: 0.938 Specificity: 0.714 Precision (PPV): 0.789 Confusion: TP=15, FP=4, TN=10, FN=1 === Fold 5 – RF (combined, fixed_0.5) (threshold = 0.500) === ROC-AUC: 0.795 PR-AUC: 0.868 Accuracy: 0.767 Sensitivity: 0.625 Specificity: 0.929 Precision (PPV): 0.909 Confusion: TP=10, FP=1, TN=13, FN=6 === Fold 5 – RF (combined, roc_opt) (threshold = 0.516) === ROC-AUC: 0.795 PR-AUC: 0.868 Accuracy: 
0.800 Sensitivity: 0.625 Specificity: 1.000 Precision (PPV): 1.000 Confusion: TP=10, FP=0, TN=14, FN=6
| threshold | auc | pr_auc | acc | sens | spec | prec | tp | fp | tn | fn | fold | thr_name | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.500000 | 0.835556 | 0.892860 | 0.833333 | 0.800000 | 0.866667 | 0.857143 | 12 | 2 | 13 | 3 | 1 | fixed_0.5 |
| 1 | 0.549905 | 0.835556 | 0.892860 | 0.833333 | 0.800000 | 0.866667 | 0.857143 | 12 | 2 | 13 | 3 | 1 | roc_opt |
| 2 | 0.500000 | 0.848889 | 0.815895 | 0.800000 | 0.800000 | 0.800000 | 0.800000 | 12 | 3 | 12 | 3 | 2 | fixed_0.5 |
| 3 | 0.618354 | 0.848889 | 0.815895 | 0.833333 | 0.733333 | 0.933333 | 0.916667 | 11 | 1 | 14 | 4 | 2 | roc_opt |
| 4 | 0.500000 | 0.808036 | 0.860377 | 0.733333 | 0.562500 | 0.928571 | 0.900000 | 9 | 1 | 13 | 7 | 3 | fixed_0.5 |
Cross-validation summary (Random Forest, combined features):
| auc | pr_auc | acc | sens | spec | prec | |||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mean | std | mean | std | mean | std | mean | std | mean | std | mean | std | |
| thr_name | ||||||||||||
| fixed_0.5 | 0.828853 | 0.026694 | 0.867316 | 0.033049 | 0.766667 | 0.052705 | 0.695000 | 0.105549 | 0.847619 | 0.091535 | 0.839913 | 0.073561 |
| roc_opt | 0.828853 | 0.026694 | 0.867316 | 0.033049 | 0.806667 | 0.043461 | 0.756667 | 0.119628 | 0.860000 | 0.113769 | 0.869799 | 0.090597 |
Training final RF on full combined dataset for SHAP explanations...
Summary written to oasis_dementia_combined_summary.txt Analysis complete.
In [22]:
# ===========================
# MINIMAL PREP CELL FOR VISUALIZATIONS (EXCEL VERSION)
# ===========================
import pandas as pd
import numpy as np
import os

# ---- EDIT YOUR FILE NAME EXACTLY AS UPLOADED ----
csv_path = "oasis_longitudinal_demographics-8d83e569fa2e2d30.xlsx"

# Read Excel, not CSV
demo = pd.read_excel(csv_path)

# Normalize headers: lower-case, spaces -> underscores.  The slash is
# kept, so "M/F" becomes "m/f".
demo.columns = [col.strip().lower().replace(" ", "_") for col in demo.columns]

# Mirror the sex column under the underscore name used below.
if "m_f" not in demo.columns and "m/f" in demo.columns:
    demo["m_f"] = demo["m/f"]

# Require the two identifier columns on every row.
demo = demo.dropna(subset=["subject_id", "group"])

# Baseline = Visit 1
baseline = demo[demo["visit"] == 1].copy()

# Binary sex encoding (M -> 1, F -> 0, case-insensitive).
baseline["sex_enc"] = baseline["m_f"].map({"M": 1, "F": 0, "m": 1, "f": 0})

# Binary dementia target: "Converted" counts as demented.
baseline["dementia_label"] = baseline["group"].map(
    {"Nondemented": 0, "Demented": 1, "Converted": 1}
)

# Feature set for visualizations
feature_cols = ["age", "mmse", "educ", "ses", "nwbv", "etiv", "asf", "sex_enc"]

# Final modeling dataframe: complete cases only.
df_model = baseline[feature_cols + ["dementia_label"]].dropna().copy()

print("df_model created:", df_model.shape)
print("feature_cols:", feature_cols)
print(df_model.head(10))
df_model created: (142, 9)
feature_cols: ['age', 'mmse', 'educ', 'ses', 'nwbv', 'etiv', 'asf', 'sex_enc']
age mmse educ ses nwbv etiv asf sex_enc dementia_label
0 87 27.0 14 2.0 0.696106 1986.55 0.88344 1 0
5 88 28.0 18 3.0 0.709512 1215.33 1.44406 0 0
7 80 28.0 12 4.0 0.711502 1688.58 1.03933 1 0
13 93 30.0 14 2.0 0.697599 1271.51 1.38024 0 0
15 68 27.0 12 2.0 0.806315 1456.60 1.20486 1 1
17 66 30.0 12 3.0 0.768708 1446.66 1.21314 0 1
19 78 29.0 16 2.0 0.747875 1333.37 1.31621 0 0
22 81 30.0 12 4.0 0.715019 1229.72 1.42716 0 0
25 76 21.0 16 3.0 0.696770 1601.89 1.09558 1 1
27 88 25.0 8 4.0 0.659691 1650.60 1.06325 1 1
In [23]:
# ============================================================
# FULL VISUALIZATION SUITE FOR OASIS-2 BASELINE df_model
# ============================================================
# Requires the prep cell above to have defined ``df_model`` and
# ``feature_cols``.  Figures are shown inline and saved under ./figures.
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os

# Create folder for figures
os.makedirs("figures", exist_ok=True)
# Pretty plotting style
sns.set(style="whitegrid", font_scale=1.3)
# ============================================================
# 1) CLASS BALANCE PLOT
# ============================================================
def plot_class_balance(df):
    """Plot and save the count of each dementia class in *df*."""
    plt.figure(figsize=(7, 5))
    ax = sns.countplot(x="dementia_label", data=df, palette="viridis")
    ax.set_title("Class Distribution: Dementia vs Non-Dementia")
    ax.set_xticks([0, 1])
    ax.set_xticklabels(["Non-Demented", "Demented/Converted"])
    ax.set_ylabel("Count")
    ax.set_xlabel("")
    plt.tight_layout()
    plt.savefig("figures/class_balance.png")
    plt.show()

plot_class_balance(df_model)
# ============================================================
# 2) FEATURE HISTOGRAMS + KDE PER CLASS
# ============================================================
def plot_feature_histograms(df, features):
    """One class-split histogram (with KDE overlay) per feature, saved to figures/."""
    for feat in features:
        plt.figure(figsize=(8, 5))
        ax = sns.histplot(data=df, x=feat, hue="dementia_label", kde=True,
                          palette="viridis", element="step")
        ax.set_title(f"Distribution of {feat} by Dementia Status")
        plt.tight_layout()
        plt.savefig(f"figures/hist_{feat}.png")
        plt.show()

plot_feature_histograms(df_model, feature_cols)
# ============================================================
# 3) BOXPLOTS PER FEATURE
# ============================================================
def plot_boxplots(df, features):
    """Class-split boxplot for every feature, saved under figures/."""
    for feat in features:
        plt.figure(figsize=(7, 5))
        ax = sns.boxplot(x="dementia_label", y=feat, data=df, palette="viridis")
        ax.set_title(f"{feat} by Dementia Status")
        ax.set_xticks([0, 1])
        ax.set_xticklabels(["Non-Demented", "Demented"])
        plt.tight_layout()
        plt.savefig(f"figures/box_{feat}.png")
        plt.show()

plot_boxplots(df_model, feature_cols)
# ============================================================
# 4) CORRELATION HEATMAP
# ============================================================
def plot_corr_heatmap(df, features=None):
    """Annotated correlation heatmap of *features* plus the target column.

    ``features`` defaults to the module-level ``feature_cols`` so existing
    calls keep working; pass an explicit list to plot a subset.  (The
    original implementation always read the global, despite taking *df*
    as a parameter.)
    """
    if features is None:
        features = feature_cols
    corr = df[features + ["dementia_label"]].corr()
    plt.figure(figsize=(12, 10))
    sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f")
    plt.title("Correlation Heatmap (Baseline Features + Target)")
    plt.tight_layout()
    plt.savefig("figures/correlation_heatmap.png")
    plt.show()

plot_corr_heatmap(df_model)
# ============================================================
# 5) PAIRPLOT (SEABORN)
# ============================================================
# Pairwise scatter with KDE diagonals for the key clinical/volumetric
# features, colored by dementia status.  savefig precedes plt.show() so
# the pairplot's own figure is captured before it is rendered.
sns.pairplot(df_model[["age","mmse","nwbv","etiv","dementia_label"]],
             hue="dementia_label", palette="viridis", diag_kind="kde")
plt.savefig("figures/pairplot.png")
plt.show()
# ============================================================
# 6) VIOLIN PLOTS — DISTRIBUTION SHAPE + SPREAD
# ============================================================
def plot_violins(df, features):
    """Class-split violin plot for every feature, saved under figures/."""
    for feat in features:
        plt.figure(figsize=(8, 5))
        ax = sns.violinplot(x="dementia_label", y=feat, data=df, palette="viridis")
        ax.set_title(f"Violin Plot: {feat} by Dementia Status")
        ax.set_xticks([0, 1])
        ax.set_xticklabels(["Non-Demented", "Demented"])
        plt.tight_layout()
        plt.savefig(f"figures/violin_{feat}.png")
        plt.show()

plot_violins(df_model, feature_cols)
# ============================================================
# 7–10) KEY BIVARIATE SCATTERPLOTS
#   7) Age vs nWBV        — key biomarker interaction
#   8) MMSE vs nWBV       — cognitive vs structural decline
#   9) Education vs SES   — sociodemographic patterns
#  10) eTIV vs nWBV       — true atrophy visualization
# ============================================================
_scatter_specs = [
    ("age", "nwbv", "Age vs Normalized Whole-Brain Volume (nWBV)",
     "figures/scatter_age_nwbv.png"),
    ("mmse", "nwbv", "MMSE vs Brain Volume (nWBV)",
     "figures/scatter_mmse_nwbv.png"),
    ("educ", "ses", "Education vs SES by Dementia Status",
     "figures/scatter_educ_ses.png"),
    ("etiv", "nwbv", "eTIV vs nWBV (Atrophy Indicator)",
     "figures/scatter_etiv_nwbv.png"),
]
for x_col, y_col, fig_title, out_path in _scatter_specs:
    plt.figure(figsize=(8, 6))
    sns.scatterplot(data=df_model, x=x_col, y=y_col,
                    hue="dementia_label", palette="viridis", s=80)
    plt.title(fig_title)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.show()

print("🎉 All visualizations generated and saved in /figures/")
🎉 All visualizations generated and saved in /figures/
In [24]:
# ============================================================
# BEAUTIFUL MULTI-COLOR VISUALIZATION SUITE (IEEE-READY)
# ============================================================
# Regenerates the figures above with varied seaborn palettes and numbered
# file names (figures/01_* .. figures/09_*).  Requires ``df_model`` and
# ``feature_cols`` from the prep cell.
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

os.makedirs("figures", exist_ok=True)

# Different palettes for variety — referenced by number in the sections
# below, so these names must stay stable.
palette1 = "viridis"
palette2 = "crest"
palette3 = "rocket"
palette4 = "flare"
palette5 = "magma"
palette6 = "coolwarm"
palette7 = "Spectral"
palette8 = "cubehelix"
palette9 = "icefire"
palette10 = "Set2"

sns.set_context("talk", font_scale=1.15)
# ============================================================
# 1) CLASS BALANCE
# ============================================================
sns.set_style("whitegrid")
plt.figure(figsize=(7,5))
sns.countplot(x="dementia_label", data=df_model, palette=palette7)
plt.title("Class Distribution", fontsize=18)
plt.xticks([0,1], ["Non-Demented", "Demented"], fontsize=14)
plt.ylabel("Count")
plt.tight_layout()
plt.savefig("figures/01_class_balance.png")
plt.show()

# ============================================================
# 2) HISTOGRAMS (Different palette each feature)
# ============================================================
# Cycle through eight palettes so consecutive features look distinct.
palettes_cycle = [palette1, palette2, palette3, palette4, palette5,
                  palette6, palette7, palette8]
sns.set_style("ticks")
for i, f in enumerate(feature_cols):
    plt.figure(figsize=(8,5))
    sns.histplot(
        df_model, x=f, hue="dementia_label",
        kde=True, element="step", alpha=0.5,
        palette=palettes_cycle[i % len(palettes_cycle)]
    )
    plt.title(f"Distribution of {f}", fontsize=18)
    plt.xlabel(f)
    plt.tight_layout()
    plt.savefig(f"figures/02_hist_{f}.png")
    plt.show()
# ============================================================
# 3) BOXPLOTS
# ============================================================
sns.set_style("whitegrid")
# NOTE(review): the enumerate index i is unused here.
for i, f in enumerate(feature_cols):
    plt.figure(figsize=(7,5))
    sns.boxplot(
        x="dementia_label", y=f, data=df_model,
        palette=palette10
    )
    plt.title(f"{f} Comparison", fontsize=18)
    plt.xticks([0,1], ["Non-Demented", "Demented"], fontsize=14)
    plt.tight_layout()
    plt.savefig(f"figures/03_box_{f}.png")
    plt.show()

# ============================================================
# 4) VIOLIN PLOTS (High quality)
# ============================================================
sns.set_style("darkgrid")
for f in feature_cols:
    plt.figure(figsize=(8,5))
    sns.violinplot(
        data=df_model,
        x="dementia_label", y=f,
        palette=palette3, inner="quartile"  # quartile lines inside each violin
    )
    plt.title(f"Violin Plot: {f}", fontsize=18)
    plt.xticks([0,1], ["Non-Demented", "Demented"], fontsize=14)
    plt.tight_layout()
    plt.savefig(f"figures/04_violin_{f}.png")
    plt.show()
# ============================================================
# 5) CORRELATION HEATMAP (advanced)
# ============================================================
# Correlates ALL numeric columns of df_model (the eight features plus
# dementia_label), unlike the earlier suite which listed them explicitly.
sns.set_style("white")
plt.figure(figsize=(12,10))
corr = df_model.corr()
sns.heatmap(
    corr, annot=True, cmap=palette9,
    linewidths=0.5, linecolor="white",
    square=True, cbar_kws={"shrink": 0.8}
)
plt.title("Correlation Matrix (Baseline Features)", fontsize=20)
plt.tight_layout()
plt.savefig("figures/05_corr_heatmap.png")
plt.show()
# ============================================================
# 6–9) SCATTERPLOTS
#   6) Age vs nWBV   7) MMSE vs nWBV
#   8) EDUC vs SES   9) eTIV vs nWBV (atrophy)
# ============================================================
sns.set_style("ticks")
_mc_scatter_specs = [
    # (x, y, palette, alpha, title, output file)
    ("age", "nwbv", palette6, 0.8, "Age vs nWBV",
     "figures/06_scatter_age_nwbv.png"),
    ("mmse", "nwbv", palette8, 0.85, "MMSE vs Brain Volume",
     "figures/07_scatter_mmse_nwbv.png"),
    ("educ", "ses", palette4, 0.8, "Education vs SES",
     "figures/08_scatter_educ_ses.png"),
    ("etiv", "nwbv", palette2, 0.8, "eTIV vs Brain Volume (Atrophy)",
     "figures/09_scatter_etiv_nwbv.png"),
]
for x_col, y_col, pal, alpha_val, fig_title, out_file in _mc_scatter_specs:
    plt.figure(figsize=(8, 6))
    sns.scatterplot(
        data=df_model, x=x_col, y=y_col,
        hue="dementia_label", palette=pal,
        s=120, alpha=alpha_val, edgecolor="black"
    )
    plt.title(fig_title, fontsize=18)
    plt.tight_layout()
    plt.savefig(out_file)
    plt.show()

print("🎉 ALL MULTI-COLOR VISUALIZATIONS GENERATED SUCCESSFULLY!")
🎉 ALL MULTI-COLOR VISUALIZATIONS GENERATED SUCCESSFULLY!